{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4fee8a1a-242b-4853-9be2-46c19482eaf7",
   "metadata": {},
   "source": [
    "# Lesson 6: Expanding the SEC Knowledge Graph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c376494-9d29-48fb-8333-24d7e6ab1e85",
   "metadata": {},
   "source": [
    "### Import packages and set up Neo4j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fdde3df-66fc-42bc-80e8-fc8991e2775e",
   "metadata": {
    "height": 284
   },
   "outputs": [],
   "source": [
    "from dotenv import load_dotenv\n",
    "import os\n",
    "import textwrap\n",
    "\n",
    "# LangChain\n",
    "from langchain_community.graphs import Neo4jGraph\n",
    "from langchain_community.vectorstores import Neo4jVector\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain.chains import RetrievalQAWithSourcesChain\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# Warning control\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c6884bf-0ccd-4be8-9843-5ad6a7d13279",
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "# Load from environment\n",
    "load_dotenv('.env', override=True)\n",
    "NEO4J_URI = os.getenv('NEO4J_URI')\n",
    "NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')\n",
    "NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')\n",
    "NEO4J_DATABASE = os.getenv('NEO4J_DATABASE') or 'neo4j'\n",
    "\n",
    "# Global constants\n",
    "VECTOR_INDEX_NAME = 'form_10k_chunks'\n",
    "VECTOR_NODE_LABEL = 'Chunk'\n",
    "VECTOR_SOURCE_PROPERTY = 'text'\n",
    "VECTOR_EMBEDDING_PROPERTY = 'textEmbedding'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c068179f-30ad-427c-a266-e4ed79494356",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "kg = Neo4jGraph(\n",
    "    url=NEO4J_URI, \n",
    "    username=NEO4J_USERNAME, \n",
    "    password=NEO4J_PASSWORD, \n",
    "    database=NEO4J_DATABASE\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6b9dda8-0f99-43bd-8104-d29829862465",
   "metadata": {},
   "source": [
    "### Read the collection of Form 13s\n",
    "- Investment management firms must report on their investments in companies to the SEC by filing a document called **Form 13**\n",
    "- You'll load a collection of Form 13s filed by managers that have invested in NetApp\n",
    "- You can check out the CSV file by navigating to the data directory using the File menu at the top of the notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c741b665-aa5d-4650-8df6-9c43febbcdf5",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "all_form13s = []\n",
    "\n",
    "with open('./data/form13.csv', mode='r', newline='') as csv_file:\n",
    "    csv_reader = csv.DictReader(csv_file)\n",
    "    for row in csv_reader:  # each row is a dictionary keyed by column name\n",
    "        all_form13s.append(row)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "344c1290-5191-4a2a-9f3c-d714ae671362",
   "metadata": {},
   "source": [
    "- Look at the contents of the first 5 Form 13s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b09890e0-6e7b-445c-84dc-c988ec473fae",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "all_form13s[0:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b443488-4815-4ed2-9927-132e3a625bfe",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "len(all_form13s)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd216140-a6ef-4226-b73a-040fde03dc84",
   "metadata": {},
   "source": [
    "### Create company nodes in the graph\n",
    "- Use the companies identified in the Form 13s to create `Company` nodes\n",
    "- For now, there is only one company - NetApp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e63a8f2-acbd-4ce1-b433-83cdbe0375bf",
   "metadata": {
    "height": 267
   },
   "outputs": [],
   "source": [
    "# work with just the first form for now\n",
    "first_form13 = all_form13s[0]\n",
    "\n",
    "cypher = \"\"\"\n",
    "MERGE (com:Company {cusip6: $cusip6})\n",
    "  ON CREATE\n",
    "    SET com.companyName = $companyName,\n",
    "        com.cusip = $cusip\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={\n",
    "    'cusip6':first_form13['cusip6'], \n",
    "    'companyName':first_form13['companyName'], \n",
    "    'cusip':first_form13['cusip'] \n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f48e6d64-b920-4738-8f4c-ce632dad923d",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "MATCH (com:Company)\n",
    "RETURN com LIMIT 1\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f472c370-8e7d-4970-ab18-dbfdf5d9a5d7",
   "metadata": {},
   "source": [
    "- Update the company name to match Form 10-K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de2b3422-6b82-4adc-8e18-1c3d195472b1",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MATCH (com:Company), (form:Form)\n",
    "    WHERE com.cusip6 = form.cusip6\n",
    "  RETURN com.companyName, form.names\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64422771-f6c4-4010-9691-eb42dc91df78",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MATCH (com:Company), (form:Form)\n",
    "    WHERE com.cusip6 = form.cusip6\n",
    "  SET com.names = form.names\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7c24776-4951-423e-a6d6-054c4ea701d5",
   "metadata": {},
   "source": [
    "- Create a `FILED` relationship between the company and the Form 10-K node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b24cc9f-3b33-4e9d-aef3-c7fc9d79c10f",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "kg.query(\"\"\"\n",
    "  MATCH (com:Company), (form:Form)\n",
    "    WHERE com.cusip6 = form.cusip6\n",
    "  MERGE (com)-[:FILED]->(form)\n",
    "\"\"\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "748c204e-fce1-402e-a5ce-6e48dda8d011",
   "metadata": {},
   "source": [
    "### Create manager nodes\n",
    "- Create a `Manager` node for companies that have filed a Form 13 to report their investment in NetApp\n",
    "- Start with the single manager who filed the first Form 13 in the list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd76c0f5-4ed6-43c3-8cc7-39ae1daa7f6e",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MERGE (mgr:Manager {managerCik: $managerParam.managerCik})\n",
    "    ON CREATE\n",
    "        SET mgr.managerName = $managerParam.managerName,\n",
    "            mgr.managerAddress = $managerParam.managerAddress\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={'managerParam': first_form13})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b88849f6-5965-48ca-9067-9882bcd5f816",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "kg.query(\"\"\"\n",
    "  MATCH (mgr:Manager)\n",
    "  RETURN mgr LIMIT 1\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc86bd60-be9f-455f-9cea-c3fbd018d212",
   "metadata": {},
   "source": [
    "- Create a uniqueness constraint to avoid duplicate managers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b39079fc-d49d-490d-ae33-189e3a56142e",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "kg.query(\"\"\"\n",
    "CREATE CONSTRAINT unique_manager \n",
    "  IF NOT EXISTS\n",
    "  FOR (n:Manager) \n",
    "  REQUIRE n.managerCik IS UNIQUE\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39fd7e0d-d09f-46bb-b2c4-f952d68636c6",
   "metadata": {},
   "source": [
    "- Create a fulltext index of manager names to enable text search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5db369-9a6d-4cae-af40-5a4efee47f01",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "kg.query(\"\"\"\n",
    "CREATE FULLTEXT INDEX fullTextManagerNames\n",
    "  IF NOT EXISTS\n",
    "  FOR (mgr:Manager) \n",
    "  ON EACH [mgr.managerName]\n",
    "\"\"\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd04b473-2ded-4efe-8b13-ba229c9d046f",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "kg.query(\"\"\"\n",
    "  CALL db.index.fulltext.queryNodes(\"fullTextManagerNames\", \n",
    "      \"royal bank\") YIELD node, score\n",
    "  RETURN node.managerName, score\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df47dbc2-4929-4cda-8671-cbd727472bb3",
   "metadata": {},
   "source": [
    "- Create `Manager` nodes for all companies that filed a Form 13"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "909eb642-92a9-4da4-8092-f3d0858243a0",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MERGE (mgr:Manager {managerCik: $managerParam.managerCik})\n",
    "    ON CREATE\n",
    "        SET mgr.managerName = $managerParam.managerName,\n",
    "            mgr.managerAddress = $managerParam.managerAddress\n",
    "\"\"\"\n",
    "# loop through all Form 13s\n",
    "for form13 in all_form13s:\n",
    "  kg.query(cypher, params={'managerParam': form13 })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39b64c10-1baf-4f67-92a6-986a5040c010",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "kg.query(\"\"\"\n",
    "    MATCH (mgr:Manager) \n",
    "    RETURN count(mgr)\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a4d1b69-a78c-4bfe-9a71-23b889d3ad76",
   "metadata": {},
   "source": [
    "### Create relationships between managers and companies\n",
    "- Match companies with managers based on data in the Form 13\n",
    "- Create an `OWNS_STOCK_IN` relationship between the manager and the company\n",
    "- Start with the single manager who filed the first Form 13 in the list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc00d1a0-62f8-4e8a-8422-b6b054537b33",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MATCH (mgr:Manager {managerCik: $investmentParam.managerCik}), \n",
    "        (com:Company {cusip6: $investmentParam.cusip6})\n",
    "  RETURN mgr.managerName, com.companyName, $investmentParam as investment\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={ \n",
    "    'investmentParam': first_form13 \n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de36066-caa3-43e1-9bf8-3cfda9a1a9eb",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "MATCH (mgr:Manager {managerCik: $ownsParam.managerCik}), \n",
    "        (com:Company {cusip6: $ownsParam.cusip6})\n",
    "MERGE (mgr)-[owns:OWNS_STOCK_IN { \n",
    "    reportCalendarOrQuarter: $ownsParam.reportCalendarOrQuarter\n",
    "}]->(com)\n",
    "ON CREATE\n",
    "    SET owns.value  = toFloat($ownsParam.value), \n",
    "        owns.shares = toInteger($ownsParam.shares)\n",
    "RETURN mgr.managerName, owns.reportCalendarOrQuarter, com.companyName\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={ 'ownsParam': first_form13 })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6eca492-60ae-49b8-910e-fc462e0945fa",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "kg.query(\"\"\"\n",
    "MATCH (mgr:Manager {managerCik: $ownsParam.managerCik})\n",
    "-[owns:OWNS_STOCK_IN]->\n",
    "        (com:Company {cusip6: $ownsParam.cusip6})\n",
    "RETURN owns { .shares, .value }\n",
    "\"\"\", params={ 'ownsParam': first_form13 })"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "228516cd-9098-4ac5-89ed-868d2cb1d2ef",
   "metadata": {},
   "source": [
    "- Create relationships between all of the managers who filed Form 13s and the company"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e54d0d1d-ddfe-410e-a1e2-9c652b248706",
   "metadata": {
    "height": 267
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "MATCH (mgr:Manager {managerCik: $ownsParam.managerCik}), \n",
    "        (com:Company {cusip6: $ownsParam.cusip6})\n",
    "MERGE (mgr)-[owns:OWNS_STOCK_IN { \n",
    "    reportCalendarOrQuarter: $ownsParam.reportCalendarOrQuarter \n",
    "    }]->(com)\n",
    "  ON CREATE\n",
    "    SET owns.value  = toFloat($ownsParam.value), \n",
    "        owns.shares = toInteger($ownsParam.shares)\n",
    "\"\"\"\n",
    "\n",
    "# loop through all Form 13s\n",
    "for form13 in all_form13s:\n",
    "  kg.query(cypher, params={'ownsParam': form13 })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcacbf4a-a588-46a8-8265-bbcd25b4312b",
   "metadata": {
    "height": 114
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "  MATCH (:Manager)-[owns:OWNS_STOCK_IN]->(:Company)\n",
    "  RETURN count(owns) as investments\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3db81fbe-f0c1-47cb-a8f3-7bb13e7db01d",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "kg.refresh_schema()\n",
    "print(textwrap.fill(kg.schema, 60))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80aab435-67f1-4d8f-a578-a099f25def94",
   "metadata": {},
   "source": [
    "### Determine the number of investors\n",
    "- Start by finding a Form 10-K chunk, and save its ID to use in subsequent queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac455645-f98c-4e91-ad8f-7ea555781613",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "    MATCH (chunk:Chunk)\n",
    "    RETURN chunk.chunkId as chunkId LIMIT 1\n",
    "    \"\"\"\n",
    "\n",
    "chunk_rows = kg.query(cypher)\n",
    "print(chunk_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d748b6e4-80c4-45d1-9afb-34381331f605",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "chunk_first_row = chunk_rows[0]\n",
    "print(chunk_first_row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09aeb04f-98ac-4939-9ad4-4caa85ef97af",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "ref_chunk_id = chunk_first_row['chunkId']\n",
    "ref_chunk_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "776dbdac-23c8-4a39-954e-4eed8f2525a4",
   "metadata": {},
   "source": [
    "- Build up a path from the Form 10-K chunk to companies and managers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d40f129d-750d-4f7f-be88-eb45281c534b",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "    MATCH (:Chunk {chunkId: $chunkIdParam})-[:PART_OF]->(f:Form)\n",
    "    RETURN f.source\n",
    "    \"\"\"\n",
    "\n",
    "kg.query(cypher, params={'chunkIdParam': ref_chunk_id})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "342246fa-715f-4078-90db-200d33232046",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "MATCH (:Chunk {chunkId: $chunkIdParam})-[:PART_OF]->(f:Form),\n",
    "    (com:Company)-[:FILED]->(f)\n",
    "RETURN com.companyName as name\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={'chunkIdParam': ref_chunk_id})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e2905fc-0d74-43e0-87c7-62d181b1703d",
   "metadata": {
    "height": 216
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "MATCH (:Chunk {chunkId: $chunkIdParam})-[:PART_OF]->(f:Form),\n",
    "        (com:Company)-[:FILED]->(f),\n",
    "        (mgr:Manager)-[:OWNS_STOCK_IN]->(com)\n",
    "RETURN com.companyName, \n",
    "        count(mgr.managerName) as numberOfInvestors \n",
    "LIMIT 1\n",
    "\"\"\"\n",
    "\n",
    "kg.query(cypher, params={\n",
    "    'chunkIdParam': ref_chunk_id\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6414d691-ac43-4405-bb34-9d9a00ca7517",
   "metadata": {},
   "source": [
    "### Use queries to build additional context for LLM\n",
    "- Create sentences that indicate how much stock a manager has invested in a company"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37b91f15-2ea2-42f3-8847-04b9d7b07c07",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "cypher = \"\"\"\n",
    "    MATCH (:Chunk {chunkId: $chunkIdParam})-[:PART_OF]->(f:Form),\n",
    "        (com:Company)-[:FILED]->(f),\n",
    "        (mgr:Manager)-[owns:OWNS_STOCK_IN]->(com)\n",
    "    RETURN mgr.managerName + \" owns \" + owns.shares + \n",
    "        \" shares of \" + com.companyName + \n",
    "        \" at a value of $\" + \n",
    "        apoc.number.format(toInteger(owns.value)) AS text\n",
    "    LIMIT 10\n",
    "    \"\"\"\n",
    "kg.query(cypher, params={\n",
    "    'chunkIdParam': ref_chunk_id\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b097f06-8189-4a4a-ae62-05be79416433",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "results = kg.query(cypher, params={\n",
    "    'chunkIdParam': ref_chunk_id\n",
    "})\n",
    "print(textwrap.fill(results[0]['text'], 60))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d53fdf0-45e1-4b9b-aea0-20207226083d",
   "metadata": {},
   "source": [
    "- Create a plain question-answering (QA) chain\n",
    "- It uses similarity search only, with no augmentation by a Cypher query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91118de7-4e4a-43a8-894b-02bb3373f489",
   "metadata": {
    "height": 335
   },
   "outputs": [],
   "source": [
    "vector_store = Neo4jVector.from_existing_graph(\n",
    "    embedding=OpenAIEmbeddings(),\n",
    "    url=NEO4J_URI,\n",
    "    username=NEO4J_USERNAME,\n",
    "    password=NEO4J_PASSWORD,\n",
    "    database=NEO4J_DATABASE,\n",
    "    index_name=VECTOR_INDEX_NAME,\n",
    "    node_label=VECTOR_NODE_LABEL,\n",
    "    text_node_properties=[VECTOR_SOURCE_PROPERTY],\n",
    "    embedding_node_property=VECTOR_EMBEDDING_PROPERTY,\n",
    ")\n",
    "# Create a retriever from the vector store\n",
    "retriever = vector_store.as_retriever()\n",
    "\n",
    "# Create a chatbot Question & Answer chain from the retriever\n",
    "plain_chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(temperature=0), \n",
    "    chain_type=\"stuff\", \n",
    "    retriever=retriever\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e45bc34-b611-46bc-8da1-8d0df506dbd4",
   "metadata": {},
   "source": [
    "- Create a second QA chain\n",
    "- Augment similarity search using sentences found by the investment query above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17067dee-df9e-4c9c-b2c8-b566aba4ccaf",
   "metadata": {
    "height": 352
   },
   "outputs": [],
   "source": [
    "investment_retrieval_query = \"\"\"\n",
    "MATCH (node)-[:PART_OF]->(f:Form),\n",
    "    (f)<-[:FILED]-(com:Company),\n",
    "    (com)<-[owns:OWNS_STOCK_IN]-(mgr:Manager)\n",
    "WITH node, score, mgr, owns, com \n",
    "    ORDER BY owns.shares DESC LIMIT 10\n",
    "WITH collect (\n",
    "    mgr.managerName + \n",
    "    \" owns \" + owns.shares + \n",
    "    \" shares in \" + com.companyName + \n",
    "    \" at a value of $\" + \n",
    "    apoc.number.format(toInteger(owns.value)) + \".\" \n",
    ") AS investment_statements, node, score\n",
    "RETURN apoc.text.join(investment_statements, \"\\n\") + \n",
    "    \"\\n\" + node.text AS text,\n",
    "    score,\n",
    "    { \n",
    "      source: node.source\n",
    "    } as metadata\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc9217cc-8014-43d6-97b2-63b4305e30bb",
   "metadata": {
    "height": 369
   },
   "outputs": [],
   "source": [
    "vector_store_with_investment = Neo4jVector.from_existing_index(\n",
    "    OpenAIEmbeddings(),\n",
    "    url=NEO4J_URI,\n",
    "    username=NEO4J_USERNAME,\n",
    "    password=NEO4J_PASSWORD,\n",
    "    database=NEO4J_DATABASE,\n",
    "    index_name=VECTOR_INDEX_NAME,\n",
    "    text_node_property=VECTOR_SOURCE_PROPERTY,\n",
    "    retrieval_query=investment_retrieval_query,\n",
    ")\n",
    "\n",
    "# Create a retriever from the vector store\n",
    "retriever_with_investments = vector_store_with_investment.as_retriever()\n",
    "\n",
    "# Create a chatbot Question & Answer chain from the retriever\n",
    "investment_chain = RetrievalQAWithSourcesChain.from_chain_type(\n",
    "    ChatOpenAI(temperature=0), \n",
    "    chain_type=\"stuff\", \n",
    "    retriever=retriever_with_investments\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f76e2e5-f3ea-4add-85b1-1cabc92819ef",
   "metadata": {},
   "source": [
    "- Compare the outputs!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "671a54ec-f7e8-4072-b275-414e5d3c1e6e",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "question = \"In a single sentence, tell me about NetApp.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d441aba-971d-4d71-9add-6e2321b40d2a",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "plain_chain(\n",
    "    {\"question\": question},\n",
    "    return_only_outputs=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72816577-b60e-4368-9c8f-7ba54866d79b",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "investment_chain(\n",
    "    {\"question\": question},\n",
    "    return_only_outputs=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a29ea42c-9aef-44aa-b1e2-37edce12aad2",
   "metadata": {},
   "source": [
    "- The LLM didn't make use of the investor information since the question didn't ask about investors\n",
    "- Change the question and ask again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "106dfca9-869d-4f55-8d23-1090c496bf92",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "question = \"In a single sentence, tell me about NetApp investors.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e286f5-b6ec-44a7-8f55-4f35771fbfaa",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "plain_chain(\n",
    "    {\"question\": question},\n",
    "    return_only_outputs=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b50849ac-d68a-410c-afda-3613dd195a19",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "investment_chain(\n",
    "    {\"question\": question},\n",
    "    return_only_outputs=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f02b3fdc-f679-42ea-b90a-9792766ea04a",
   "metadata": {},
   "source": [
    "### Try for yourself\n",
    "\n",
    "- Try changing the query above to retrieve other information\n",
    "- Try asking different questions\n",
    "- Note: if you change the Cypher query, you'll need to re-create the vector store, retriever, and QA chain"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
